from __future__ import division

from models import *
from utils.utils import *
from utils.datasets import *
from utils.parse_config import *

import os
import sys
import time
import datetime
import argparse

import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from torch.autograd import Variable
import torch.optim as optim

def _str2bool(value: str) -> bool:
    """Convert a command-line string such as 'true' / 'false' into a bool.

    argparse's ``type=bool`` is a trap: argparse passes the raw string to
    ``bool()``, and every non-empty string (including 'False' and '0') is
    truthy, so a boolean flag declared that way can never be switched off.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean
            spelling; argparse reports it as a usage error.
    """
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % value)


# Command-line options for training.
parser = argparse.ArgumentParser()
parser.add_argument('--epochs', type=int, default=30, help='number of epochs')
parser.add_argument('--image_folder', type=str, default='data/samples', help='path to dataset')
parser.add_argument('--batch_size', type=int, default=16, help='size of each image batch')
parser.add_argument('--model_config_path', type=str, default='config/yolov3.cfg', help='path to model config file')
parser.add_argument('--data_config_path', type=str, default='config/coco.data', help='path to data config file')
parser.add_argument('--weights_path', type=str, default='weights/yolov3.weights', help='path to weights file')
parser.add_argument('--class_path', type=str, default='data/coco.names', help='path to class label file')
parser.add_argument('--conf_thres', type=float, default=0.8, help='object confidence threshold')
parser.add_argument('--nms_thres', type=float, default=0.4, help='iou threshold for non-maximum suppression')
parser.add_argument('--n_cpu', type=int, default=0, help='number of cpu threads to use during batch generation')
parser.add_argument('--img_size', type=int, default=416, help='size of each image dimension')
parser.add_argument('--checkpoint_interval', type=int, default=1, help='interval between saving model weights')
parser.add_argument('--checkpoint_dir', type=str, default='checkpoints', help='directory where model checkpoints are saved')
# _str2bool instead of bool so that `--use_cuda False` actually disables CUDA.
parser.add_argument('--use_cuda', type=_str2bool, default=True, help='whether to use cuda if available')
opt = parser.parse_args()
print(opt)

# Train on the GPU only when one is available AND the user has not disabled it.
cuda = torch.cuda.is_available() and opt.use_cuda

# Create the output directories; exist_ok=True makes this a no-op if they exist.
os.makedirs('output', exist_ok=True)
# Checkpoints are saved under opt.checkpoint_dir, so create that directory
# rather than a hard-coded 'checkpoints' (identical for the default value).
os.makedirs(opt.checkpoint_dir, exist_ok=True)

# Class label names read from opt.class_path (not used further in this script).
classes = load_classes(opt.class_path)

# Data configuration: its 'train' entry is the training-set path that is
# handed to ListDataset when the dataloader is built.
data_config = parse_data_config(opt.data_config_path)
train_path = data_config['train']

# Hyper-parameters: element 0 of the parsed model config holds the
# network-wide settings (the [net] section, per the original author's note).
hyperparams = parse_model_config(opt.model_config_path)[0]
# Convert the raw config values (presumably strings) to numbers.
# burn_in is parsed but not otherwise used in this script.
learning_rate, momentum, decay = (
    float(hyperparams[key]) for key in ('learning_rate', 'momentum', 'decay')
)
burn_in = int(hyperparams['burn_in'])

# Build the Darknet network from the model config and initialise every
# submodule with weights_init_normal. Pretrained weights from
# opt.weights_path are NOT loaded; call model.load_weights(opt.weights_path)
# here instead to start from them.
model = Darknet(opt.model_config_path)
model.apply(weights_init_normal)

# Move parameters to the GPU when CUDA was selected.
if cuda:
    model = model.cuda()

# Switch the module (and all its children) into training mode.
model.train()

# Training data. Each batch is unpacked by the training loop into three items:
# an ignored first element, the images and the targets.
train_dataset = ListDataset(train_path)
# NOTE(review): shuffle=False presents batches in the same order every epoch;
# shuffling is the usual choice for SGD training - confirm this is intended.
dataloader = DataLoader(
    train_dataset,
    batch_size=opt.batch_size,
    shuffle=False,
    num_workers=opt.n_cpu,
)

# Tensor type that each batch is cast to before the forward pass.
if cuda:
    Tensor = torch.cuda.FloatTensor
else:
    Tensor = torch.FloatTensor

# SGD with momentum and L2 weight decay, using the hyper-parameters read from
# the model config. dampening=0 applies the full gradient in the momentum
# update (momentum smooths updates; it does not decay the learning rate).
optimizer = optim.SGD(
    model.parameters(),
    lr=learning_rate,
    momentum=momentum,
    dampening=0,
    weight_decay=decay,
)
# weight_decay权重衰减(L2正则化)
for epoch in range(opt.epochs):
    # One epoch is one full pass over the training dataloader.
    for batch_i, (_, imgs, targets) in enumerate(dataloader):
        # Cast the batch to the selected float tensor type (GPU or CPU).
        # Plain tensors are used directly: torch.autograd.Variable has been a
        # deprecated no-op wrapper since PyTorch 0.4, and tensors do not
        # require gradients by default, so targets need no extra flag.
        imgs = imgs.type(Tensor)
        targets = targets.type(Tensor)

        # Gradients accumulate across backward() calls, so clear them first.
        optimizer.zero_grad()

        # With targets supplied, the model's output is used as the loss.
        loss = model(imgs, targets)

        loss.backward()
        optimizer.step()

        print('[Epoch %d/%d, Batch %d/%d] [Losses: x %f, y %f, w %f, h %f, conf %f, cls %f, total %f, recall: %.5f]' %
              (epoch, opt.epochs, batch_i, len(dataloader),
               model.losses['x'], model.losses['y'], model.losses['w'],
               model.losses['h'], model.losses['conf'], model.losses['cls'],
               loss.item(), model.losses['recall']))

        # Running count of images processed, stored on the model.
        model.seen += imgs.size(0)

    # Save weights every `checkpoint_interval` epochs (starting at epoch 0).
    # An interval of 0 disables checkpointing instead of raising
    # ZeroDivisionError from the modulo.
    if opt.checkpoint_interval and epoch % opt.checkpoint_interval == 0:
        model.save_weights('%s/%d.weights' % (opt.checkpoint_dir, epoch))
